{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4982eb97",
   "metadata": {},
   "source": [
    "## 流自定义生成器函数\n",
    "\n",
    "您可以在 LCEL 管道中使用生成器函数（即使用 yield 关键字且行为类似于迭代器的函数）。\n",
    "\n",
    "这些生成器的签名应该是 Iterator[Input] -> Iterator[Output] 。或者对于异步生成器： AsyncIterator[Input] -> AsyncIterator[Output] 。\n",
    "\n",
    "这些对于： - 实现自定义输出解析器 - 修改上一步的输出，同时保留流功能\n",
    "\n",
    "让我们为逗号分隔列表实现一个自定义输出解析器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7bc0cb29",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI, OpenAI\n",
    "openai_api_key = \"EMPTY\"\n",
    "openai_api_base = \"http://127.0.0.1:1234/v1\"\n",
    "model = ChatOpenAI(\n",
    "    openai_api_key=openai_api_key,\n",
    "    openai_api_base=openai_api_base,\n",
    "    temperature=0.3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b4d27e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Iterator, List\n",
    "\n",
    "from langchain.prompts.chat import ChatPromptTemplate\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\n",
    "    \"响应以CSV的格式返回中文列表，不要返回其他内容。请输出与{transportation}类似的交通工具\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "72edc909",
   "metadata": {},
   "outputs": [],
   "source": [
    "str_chain = prompt | model | StrOutputParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "285f55a3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\"直升机, 热气球, 滑翔机, 飞艇, 火箭\"'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str_chain.invoke({\"transportation\":\"飞机\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "8273c002",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\"直升机, 热气球, 滑翔机, 飞艇, 火箭\""
     ]
    }
   ],
   "source": [
    "for chunk in str_chain.stream({\"transportation\":\"飞机\"}):\n",
    "    print(chunk, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "024ac8ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 这是一个自定义解析器，用于拆分 llm 令牌的迭代器\n",
    "# 放入以逗号分隔的字符串列表中\n",
    "def split_into_list(input: Iterator[str]) -> Iterator[List[str]]:\n",
    "    # 保留部分输入，直到得到逗号\n",
    "    buffer = \"\"\n",
    "    for chunk in input:\n",
    "        # 将当前块添加到缓冲区\n",
    "        buffer += chunk\n",
    "        # 当缓冲区中有逗号时\n",
    "        while \",\" in buffer:\n",
    "            # 以逗号分割缓冲区\n",
    "            comma_index = buffer.index(\",\")\n",
    "            # 产生逗号之前的所有内容\n",
    "            yield [buffer[:comma_index].strip()]\n",
    "            # 保存其余部分以供下一次迭代使用\n",
    "            buffer = buffer[comma_index + 1 :]\n",
    "    # 产生最后一个块\n",
    "    yield [buffer.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "ec8533e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\"直升机']['热气球']['滑翔机']['无人机']['飞艇\"']"
     ]
    }
   ],
   "source": [
    "list_chain = str_chain | split_into_list\n",
    "for chunk in list_chain.stream({\"transportation\":\"飞机\"}):\n",
    "    print(chunk, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8c6dea8",
   "metadata": {},
   "source": [
    "## 异步版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "670c77cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import AsyncIterator\n",
    "\n",
    "\n",
    "async def asplit_into_list(\n",
    "    input: AsyncIterator[str],\n",
    ") -> AsyncIterator[List[str]]:  # async def\n",
    "    buffer = \"\"\n",
    "    async for (\n",
    "        chunk\n",
    "    ) in input:  # `input` 是一个 `async_generator` 对象，所以使用 `async for`\n",
    "        buffer += chunk\n",
    "        while \",\" in buffer:\n",
    "            comma_index = buffer.index(\",\")\n",
    "            yield [buffer[:comma_index].strip()]\n",
    "            buffer = buffer[comma_index + 1 :]\n",
    "    yield [buffer.strip()]\n",
    "\n",
    "\n",
    "list_chain = str_chain | asplit_into_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "4f38dc41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\"直升机']\n",
      "['无人机']\n",
      "['热气球']\n",
      "['滑翔机']\n",
      "['飞艇\"']\n"
     ]
    }
   ],
   "source": [
    "async for chunk in list_chain.astream({\"transportation\":\"飞机\"}):\n",
    "    print(chunk, flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "26bc9acf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['\"直升机', '热气球', '滑翔机', '飞艇', '火箭\"']"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await list_chain.ainvoke({\"transportation\":\"飞机\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b10a53ff",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
